package xin.marcher.wind.migrate.sharding.algorithm;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.shardingsphere.sharding.api.sharding.complex.ComplexKeysShardingAlgorithm;
import org.apache.shardingsphere.sharding.api.sharding.complex.ComplexKeysShardingValue;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Database-level routing algorithm for ShardingSphere's complex (multi-column) sharding strategy.
 *
 * <p>A precise sharding value is reduced to its last {@value #SUFFIX_LENGTH} characters. The
 * {@link String#hashCode()} of that suffix, modulo the number of available databases, gives an
 * index, and the database whose name ends in exactly that number is selected.
 *
 * <p>NOTE(review): assumes the available database names carry a trailing numeric index covering
 * {@code 0 .. dbs.size() - 1} (e.g. {@code ds_0}, {@code ds_1}, ...); confirm against the
 * data-source configuration.
 */
public class DbShardingAlgorithm implements ComplexKeysShardingAlgorithm<Comparable<?>> {

    /** Number of trailing characters of the sharding value that take part in routing. */
    private static final int SUFFIX_LENGTH = 3;

    /**
     * Routes a statement to the database(s) holding the given sharding values.
     *
     * <p>Only the first sharding column (in map iteration order) that carries precise values is
     * used. When no precise values are present (e.g. the statement only has range conditions),
     * every available database is returned so the statement is broadcast.
     *
     * @param dbs           names of the available databases
     * @param shardingValue sharding values extracted from the SQL
     * @return the target database names, never {@code null}
     * @throws IllegalStateException if a sharding value maps to no available database
     */
    @Override
    public Collection<String> doSharding(Collection<String> dbs, ComplexKeysShardingValue<Comparable<?>> shardingValue) {
        Map<String, Collection<Comparable<?>>> columnNameAndShardingValuesMap = shardingValue.getColumnNameAndShardingValuesMap();
        if (columnNameAndShardingValuesMap != null) {
            for (Map.Entry<String, Collection<Comparable<?>>> entry : columnNameAndShardingValuesMap.entrySet()) {
                Collection<Comparable<?>> values = entry.getValue();
                if (CollectionUtils.isNotEmpty(values)) {
                    String column = entry.getKey();
                    return values.stream()
                            .map(value -> routeOrFail(column, value, dbs))
                            .collect(Collectors.toSet());
                }
            }
        }
        // No precise value to route on: a null result would make the caller fail,
        // so fall back to a full route across all databases.
        return new LinkedHashSet<>(dbs);
    }

    /**
     * Resolves one sharding value to its database, failing with a descriptive message instead of
     * putting {@code null} into the route result.
     */
    private String routeOrFail(String column, Comparable<?> value, Collection<String> dbs) {
        String db = getActualDbName(String.valueOf(value), dbs);
        if (db == null) {
            throw new IllegalStateException("No database matches sharding value '" + value
                    + "' of column '" + column + "' among " + dbs);
        }
        return db;
    }

    /**
     * Computes the database a single sharding value routes to.
     *
     * @param shardingValue the sharding value rendered as a string; must not be {@code null}
     * @param dbs           names of the available databases
     * @return the matching database name, or {@code null} if {@code dbs} is null/empty or no
     *         database name ends in the computed index
     * @throws NullPointerException if {@code shardingValue} is {@code null}
     */
    public String getActualDbName(String shardingValue, Collection<String> dbs) {
        Objects.requireNonNull(shardingValue, "shardingValue");
        if (CollectionUtils.isEmpty(dbs)) {
            // Nothing to route to; also avoids a modulo by zero below.
            return null;
        }
        // Take the last SUFFIX_LENGTH characters (the whole value when it is shorter).
        // StringUtils.substring with a negative start would count from the end instead ("12" -> "2").
        String suffix = StringUtils.right(shardingValue, SUFFIX_LENGTH);
        // floorMod keeps the index non-negative regardless of the hash's sign.
        String targetIndex = String.valueOf(Math.floorMod(suffix.hashCode(), dbs.size()));
        for (String db : dbs) {
            // Compare the complete trailing number so index 1 does not also match "ds_11".
            if (db != null && targetIndex.equals(trailingNumber(db))) {
                return db;
            }
        }
        return null;
    }

    /**
     * Returns the run of ASCII digits at the end of {@code name} with leading zeros removed
     * ({@code "ds_01"} -> {@code "1"}, {@code "ds_00"} -> {@code "0"}), or {@code ""} when the
     * name does not end in a digit.
     */
    private static String trailingNumber(String name) {
        int end = name.length();
        int start = end;
        while (start > 0) {
            char c = name.charAt(start - 1);
            if (c < '0' || c > '9') {
                break;
            }
            start--;
        }
        // Strip leading zeros but keep at least one digit, so zero-padded names still match.
        while (start < end - 1 && name.charAt(start) == '0') {
            start++;
        }
        return name.substring(start);
    }

}